import torch
import ntk_utils
import pandas as pd
from pylab import *
import argparse

# Command-line interface: which dataset/architecture directory to read from,
# which experiment is the reference, and two integer switches (--if_bn, --test).
parser = argparse.ArgumentParser()
parser.add_argument(
	'-a', '--arch',
	metavar='ARCH',
	default='resnet',
	help='model architecture',
)

# The remaining options share the same shape: (flag, default, type, help text).
for opt_flag, opt_default, opt_type, opt_help in (
	('--dataset', 'cifar10', str, 'which dataset used to train'),
	('--exp_base', 'te', str, 'exp name'),
	('--if_bn', 0, int, 'if use bn'),
	('--test', 0, int, 'if on test set'),
):
	parser.add_argument(opt_flag, default=opt_default, type=opt_type, help=opt_help)

args = parser.parse_args()


# Experiment names to compare against the reference experiment; the "_nobn"
# variants are selected when --if_bn is 0.
methods = (
	['normal', 'fgsm_of', 'pgd', 'te']
	if args.if_bn != 0
	else ['normal_nobn', 'fgsm_of_nobn', 'pgd_nobn', 'te_nobn']
)


# Epochs to evaluate: every epoch from 1 through 9, then every 10th epoch
# from 10 through 200 inclusive.
epoch_list = list(range(1, 10)) + list(range(10, 201, 10))

print(epoch_list)

# Path templates for the input matrices. The dataset/arch prefix is filled in
# here; the remaining "%s" (experiment name) and "%d" (epoch) placeholders are
# filled in when each file is loaded. Non-zero --test selects the "_test" files.
input_root = './%s/%s/' % (args.dataset, args.arch)
if args.test != 0:
	matrix_ae_clean_path = input_root + '%s/matrix_ae_clean%d_test.pt'
	matrix_ae_pgd_path = input_root + '%s/matrix_ae_pgd%d_test.pt'
else:
	matrix_ae_clean_path = input_root + '%s/matrix_ae_clean%d.pt'
	matrix_ae_pgd_path = input_root + '%s/matrix_ae_pgd%d.pt'


# One (initially empty) list of per-epoch distances for every compared method;
# each key gets its own independent list.
ae_clean_distance_list = {name: [] for name in methods}
ae_pgd_distance_list = {name: [] for name in methods}

# For every epoch, compare the reference experiment (args.exp_base) against
# each method: load both experiments' matrices, reduce them with
# ntk_utils.calculate_class_average_matrix, and append the value returned by
# ntk_utils.cal_kernel_distance to that method's list.
for epoch in epoch_list:
	# The reference matrices depend only on the epoch, so they are loaded and
	# class-averaged once per epoch rather than once per method.
	# NOTE(review): this assumes the ntk_utils helpers do not modify their
	# arguments in place -- confirm against ntk_utils.
	matrix_ae_clean = torch.load(matrix_ae_clean_path % (args.exp_base, epoch)).float()
	matrix_ae_pgd = torch.load(matrix_ae_pgd_path % (args.exp_base, epoch)).float()

	class_matrix_ae_clean = ntk_utils.calculate_class_average_matrix(matrix_ae_clean)
	class_matrix_ae_pgd = ntk_utils.calculate_class_average_matrix(matrix_ae_pgd)

	for m in methods:
		# NOTE(review): the reference matrices are cast with .float() but the
		# per-method matrices are not -- confirm this asymmetry is intended.
		matrix_ae_clean_2 = torch.load(matrix_ae_clean_path % (m, epoch))
		matrix_ae_pgd_2 = torch.load(matrix_ae_pgd_path % (m, epoch))

		class_matrix_ae_clean_2 = ntk_utils.calculate_class_average_matrix(matrix_ae_clean_2)
		class_matrix_ae_pgd_2 = ntk_utils.calculate_class_average_matrix(matrix_ae_pgd_2)

		ae_clean_distance_list[m].append(
			ntk_utils.cal_kernel_distance(
				class_matrix_ae_clean.numpy(), class_matrix_ae_clean_2.numpy()
			)
		)
		ae_pgd_distance_list[m].append(
			ntk_utils.cal_kernel_distance(
				class_matrix_ae_pgd.numpy(), class_matrix_ae_pgd_2.numpy()
			)
		)


# Dump both distance dictionaries (method name -> per-epoch distances) to
# stdout, followed by blank lines.
for report_label, report_values in (
	('ae_vs_clean Dis:', ae_clean_distance_list),
	('ae_vs_pgd Dis:', ae_pgd_distance_list),
):
	print(report_label, report_values)

print('\n')

# Rebind the path-template variables to the output file templates. Here the
# remaining placeholders are the experiment name and a string tag; non-zero
# --test selects the "_test" file names.
output_root = './%s/%s/' % (args.dataset, args.arch)
if args.test != 0:
	matrix_ae_clean_path = output_root + '%s/ae_clean%s_test.pt'
	matrix_ae_pgd_path = output_root + '%s/ae_pgd%s_test.pt'
else:
	matrix_ae_clean_path = output_root + '%s/ae_clean%s.pt'
	matrix_ae_pgd_path = output_root + '%s/ae_pgd%s.pt'

# Write each distance dictionary under the reference experiment's directory.
for distance_dict, output_template in (
	(ae_clean_distance_list, matrix_ae_clean_path),
	(ae_pgd_distance_list, matrix_ae_pgd_path),
):
	torch.save(distance_dict, output_template % (args.exp_base, 'cross_dist_dict'))









